//+------------------------------------------------------------------+
//|                                   RNN timeseries forecasting.mq5 |
//|                                     Copyright 2023, Omega Joctan |
//|                        https://www.mql5.com/en/users/omegajoctan |
//+------------------------------------------------------------------+
#property copyright "Copyright 2023, Omega Joctan"
#property link      "https://www.mql5.com/en/users/omegajoctan"
#property version   "1.00"

//--- Files embedded into the compiled EA at build time (paths are relative to MQL5\Files).
//--- The scaler parameters are presumably the per-feature means/scales exported from the
//--- python StandardScaler used in training - confirm they match the model's feature order.
#resource "\\Files\\rnn.EURUSD.D1.onnx" as uchar onnx_model[]; //rnn model in onnx format
#resource "\\Files\\standard_scaler_mean.bin" as double standardization_mean[];
#resource "\\Files\\standard_scaler_scale.bin" as double standardization_std[];

//--- RNN wrapper around the embedded ONNX model (initialized in OnInit())
#include <RNN.mqh>
CRNN rnn;

//--- Standardization scaler applied to every input window before prediction.
//--- Heap-allocated with `new` in OnInit().
#include <preprocessing.mqh>
StandardizationScaler *scaler; //For loading the scaling technique

//--- Standard library trade helpers: order execution and open-position inspection
#include <Trade\Trade.mqh>
#include <Trade\PositionInfo.mqh>

CTrade m_trade;
CPositionInfo m_position;

MqlDateTime date_time_struct; // scratch structure for TimeToStruct() conversions

//--- Indicator handles created in OnInit(); INVALID_HANDLE until successfully created,
//--- so cleanup code can tell "never created" apart from a real handle.
int ma_handle = INVALID_HANDLE;
int stddev_handle = INVALID_HANDLE;
//+------------------------------------------------------------------+
//| Input parameters                                                 |
//| NOTE: a comment on the same line as an `input` becomes its label |
//| in the inputs dialog, so parameter docs are kept on separate     |
//| lines above each input.                                          |
//+------------------------------------------------------------------+
input group "rnn";
//--- Number of bars per sample fed to the RNN.
//--- This value must be the same as the one used during training in the python script.
input uint rnn_time_step = 7;

input ENUM_TIMEFRAMES timeframe = PERIOD_D1;
input int magic_number = 1945;
//--- Maximum price deviation (points) accepted when sending orders
input int slippage = 50;
//--- Stop loss / take profit distances in points
input int stoploss = 500;
input int takeprofit = 700;

int OldNumBars =0;   // last seen bar count (new-bar bookkeeping)
double lotsize;      // trade volume, set to the symbol's minimum volume in OnInit()

//--- Class labels in the same order as the model's outputs. They have to be assigned manually and
//--- it is very important that their order matches the one seen in the python code.
//--- HINT: they are usually in ascending order.
vector classes_in_data_ = {0,1};
//+------------------------------------------------------------------+
//| Expert initialization function                                   |
//+------------------------------------------------------------------+
// Initializes the EA: validates inputs, loads the ONNX model, builds the
// standardization scaler, configures CTrade and creates the indicator handles.
// Returns INIT_SUCCEEDED, INIT_PARAMETERS_INCORRECT for a bad window length,
// or INIT_FAILED when any resource cannot be created.
int OnInit()
  {
//--- An empty input window cannot be fed to the model
   if(rnn_time_step==0)
     {
      Print("rnn_time_step must be greater than zero");
      return INIT_PARAMETERS_INCORRECT;
     }

//--- Initialize ONNX model
   if(!rnn.Init(onnx_model))
      return INIT_FAILED;

//--- Initializing the scaler with values loaded from binary files.
//--- Globals survive a re-initialization (e.g. input or timeframe change),
//--- so free any scaler left over from a previous run before allocating.
   if(CheckPointer(scaler)==POINTER_DYNAMIC)
      delete scaler;

   scaler = new StandardizationScaler(standardization_mean, standardization_std);
   if(CheckPointer(scaler)==POINTER_INVALID)
     {
      Print("Failed to create the standardization scaler");
      return INIT_FAILED;
     }

//--- Initializing the CTrade library for executing trades
   m_trade.SetExpertMagicNumber(magic_number);
   m_trade.SetDeviationInPoints(slippage);
   m_trade.SetMarginMode();
   m_trade.SetTypeFillingBySymbol(Symbol());

   lotsize = SymbolInfoDouble(Symbol(), SYMBOL_VOLUME_MIN);

//--- Initializing the indicators; fail early instead of reading invalid buffers later
   ma_handle = iMA(Symbol(),timeframe,30,0,MODE_SMA,PRICE_WEIGHTED); //The moving average over 30 bars
   if(ma_handle==INVALID_HANDLE)
     {
      printf("Failed to create the moving average indicator err=%d",GetLastError());
      return INIT_FAILED;
     }

   stddev_handle = iStdDev(Symbol(), timeframe, 7,0,MODE_SMA,PRICE_WEIGHTED); //The standard deviation over 7 bars
   if(stddev_handle==INVALID_HANDLE)
     {
      printf("Failed to create the standard deviation indicator err=%d",GetLastError());
      return INIT_FAILED;
     }

//---
   return(INIT_SUCCEEDED);
  }
//+------------------------------------------------------------------+
//| Expert deinitialization function                                 |
//+------------------------------------------------------------------+
// Releases everything OnInit() acquired: the heap-allocated scaler and the
// indicator handles. Also clears the chart comment written by OnTick().
// Called on removal, re-initialization and after a failed OnInit().
void OnDeinit(const int reason)
  {
//--- Free the scaler created with `new` in OnInit()
   if(CheckPointer(scaler)==POINTER_DYNAMIC)
      delete scaler;
   scaler = NULL;

//--- Release the indicator handles so the terminal can unload the indicators
   if(ma_handle!=INVALID_HANDLE)
      IndicatorRelease(ma_handle);
   if(stddev_handle!=INVALID_HANDLE)
      IndicatorRelease(stddev_handle);
   ma_handle = INVALID_HANDLE;
   stddev_handle = INVALID_HANDLE;

//--- Remove the signal text left on the chart
   Comment("");
  }
//+------------------------------------------------------------------+
//| Expert tick function                                             |
//+------------------------------------------------------------------+
// Runs once per new bar: builds the standardized input window, asks the RNN
// for a binary signal and opens a position in that direction if none with
// this EA's magic number exists yet (1 = buy, 0 = sell, anything else = skip).
void OnTick()
  {
//--- Trade at the opening of a new candle only
   if(!NewBar())
      return;

//--- The scaler is created in OnInit(); never dereference an invalid pointer
   if(CheckPointer(scaler)==POINTER_INVALID)
      return;

   matrix input_data_matrix = GetInputData(rnn_time_step);
   if(input_data_matrix.Rows()==0) //no input window could be built for this bar
      return;

   input_data_matrix = scaler.transform(input_data_matrix); //applying StandardScaler to the input data

   int signal = rnn.predict_bin(input_data_matrix, classes_in_data_); //getting trade signal from the RNN model

   Comment("Signal==",signal);

//--- Fresh prices are required for entry, SL and TP; never trade on an unfilled tick
   MqlTick ticks;
   if(!SymbolInfoTick(Symbol(), ticks))
     {
      printf("Failed to get the latest tick err=%d",GetLastError());
      return;
     }

//--- NOTE(review): SL is offset from the bid and TP from the ask for buys (mirrored
//--- for sells), so the spread is added to the TP distance only - confirm this is intended.
   if(signal==1) //bullish signal
     {
      if(!PosExists(POSITION_TYPE_BUY)) //There are no buy positions
         if(!m_trade.Buy(lotsize, Symbol(), ticks.ask, ticks.bid-stoploss*Point(), ticks.ask+takeprofit*Point())) //Open a buy trade
            printf("Failed to open a buy position err=%d",GetLastError());
     }
   else if(signal==0) //bearish signal
     {
      if(!PosExists(POSITION_TYPE_SELL)) //There are no sell positions
         if(!m_trade.Sell(lotsize, Symbol(), ticks.bid, ticks.ask+stoploss*Point(), ticks.bid-takeprofit*Point())) //Open a sell trade
            printf("Failed to open a sell position err=%d",GetLastError());
     }
//--- any other value means the prediction failed: do nothing this bar
  }
//+------------------------------------------------------------------+
//|                                                                  |
//+------------------------------------------------------------------+
// Returns true exactly once per bar of the `timeframe` input (the same
// timeframe the model's data and indicators use), on the first call after
// that bar opens. Returns false while the bar time is unavailable (history
// not loaded yet), so detection is simply retried on the next tick.
bool NewBar()
  {
//--- Open time of the newest bar seen so far; persists between calls
   static datetime last_bar_time = 0;

//--- Compare bar open times rather than bar counts: the count stops changing
//--- once the terminal's "max bars in chart" limit is reached.
   datetime current_bar_time = iTime(Symbol(), timeframe, 0);
   if(current_bar_time==0)
      return false;

   if(current_bar_time==last_bar_time)
      return false;

   last_bar_time = current_bar_time;
   return true;
  }
//+------------------------------------------------------------------+
//|                                                                  |
//+------------------------------------------------------------------+
// Reports whether an open position of the given direction exists on the
// current symbol with this EA's magic number. Positions are scanned from
// the last index down to the first.
bool PosExists(ENUM_POSITION_TYPE type)
  {
   int index = PositionsTotal();
   while(index-- > 0)
     {
      if(!m_position.SelectByIndex(index))
         continue;

      //--- filters are checked in order; later getters run only when earlier ones match
      if(m_position.Symbol()!=Symbol())
         continue;
      if(m_position.Magic()!=magic_number)
         continue;
      if(m_position.PositionType()==type)
         return (true);
     }

   return (false);
  }
//+------------------------------------------------------------------+
//|                                                                  |
//+------------------------------------------------------------------+
// Builds the model input window: one row per bar (`bars` rows starting at
// shift `start_bar` on `timeframe`) and 10 feature columns:
// open, high, low, close, moving average, std deviation,
// day of month, day of week, day of year, month.
// Returns an empty matrix (Rows()==0) if any series cannot be copied in full,
// so callers never receive a partially filled or misaligned window.
matrix GetInputData(int bars, int start_bar=1)
  {
   const int n_features = 10;   //we have 10 inputs for the rnn | this value is fixed
   matrix empty;                //returned whenever the window cannot be built completely

   if(bars<=0)
      return empty;

//--- Getting OHLC values and bar open times
   vector open, high, low, close, time_vector;
   if(!open.CopyRates(Symbol(), timeframe, COPY_RATES_OPEN, start_bar, bars) ||
      !high.CopyRates(Symbol(), timeframe, COPY_RATES_HIGH, start_bar, bars) ||
      !low.CopyRates(Symbol(), timeframe, COPY_RATES_LOW, start_bar, bars) ||
      !close.CopyRates(Symbol(), timeframe, COPY_RATES_CLOSE, start_bar, bars) ||
      !time_vector.CopyRates(Symbol(), timeframe, COPY_RATES_TIME, start_bar, bars))
     {
      printf("Failed to copy rates err=%d",GetLastError());
      return empty;
     }

//--- Getting moving average and standard deviation values
   vector ma, stddev;
   if(!ma.CopyIndicatorBuffer(ma_handle, 0, start_bar, bars) ||
      !stddev.CopyIndicatorBuffer(stddev_handle, 0, start_bar, bars))
     {
      printf("Failed to copy indicator buffers err=%d",GetLastError());
      return empty;
     }

//--- A short copy (e.g. history still loading) would misalign the feature columns
   ulong expected = (ulong)bars;
   if(open.Size()!=expected || high.Size()!=expected || low.Size()!=expected ||
      close.Size()!=expected || time_vector.Size()!=expected ||
      ma.Size()!=expected || stddev.Size()!=expected)
     {
      printf("Incomplete input window: expected %d bars",bars);
      return empty;
     }

//--- Extracting calendar features straight from each bar's open time.
//--- NOTE(review): MqlDateTime uses day_of_week 0=Sunday and a zero-based day_of_year;
//--- confirm the python training features used the same conventions.
   vector dayofmonth(bars), dayofweek(bars), dayofyear(bars), month(bars);
   MqlDateTime bar_time;
   for(int i=0; i<bars; i++)
     {
      TimeToStruct((datetime)time_vector[i], bar_time);

      dayofmonth[i] = bar_time.day;
      dayofweek[i]  = bar_time.day_of_week;
      dayofyear[i]  = bar_time.day_of_year;
      month[i]      = bar_time.mon;
     }

//--- Adding the features into the data matrix (column order must match training)
   matrix data(bars, n_features);

   data.Col(open, 0);
   data.Col(high, 1);
   data.Col(low, 2);
   data.Col(close, 3);
   data.Col(ma, 4);
   data.Col(stddev, 5);
   data.Col(dayofmonth, 6);
   data.Col(dayofweek, 7);
   data.Col(dayofyear, 8);
   data.Col(month, 9);

   return data;
  }
//+------------------------------------------------------------------+
//|                                                                  |
//+------------------------------------------------------------------+


